module ModuleBEGS
    
    use GLOBALS
    use toolbox
    
contains
    
    !> Welfare comparison between the initial steady state and the transition.
    !>
    !> 1. Builds the (a,x) transition matrix implied by the initial-steady-state
    !>    asset policy, solves C_aux = u(c_pol_init) + beta*Qmu*C_aux by
    !>    fixed-point iteration and converts V_TR(:,1) - V_init into
    !>    state-by-state consumption equivalents CE (aggregate printed).
    !> 2. Computes conditional expected moments MOM(is, var, period, model) of
    !>    var = c, h, log c, log h, (log c)**2, (log h)**2, with model = 1 for the
    !>    initial steady state and model = 2 for the transition.
    !>    Written to Output/MOMload.txt.
    !> 3. Splits the welfare change into efficiency / redistribution / insurance
    !>    terms per period (row 1 = consumption, row 2 = labor).
    !>    Written to Output/CE_eff.txt, Output/CE_red.txt, Output/CE_ins.txt.
    !>
    !> NOTE(review): every 1/(1-sg) below assumes sg /= 1 (no log-utility case).
    subroutine Compute_CE

        implicit none

        integer, parameter :: max_iter = 1000   ! cap on fixed-point iterations for C_aux
        real(8), parameter :: tol      = 1d-8   ! stop when ||C_aux_new - C_aux||_2 < tol

        real(8) :: C_aux(NS), C_aux_new(NS), CE(NS), u_aux(NS), err
        real(8) :: matu(NS,6)
        real(8) :: Cagg_TR(T_TR), Cagg_init, Hagg_TR(T_TR), Hagg_init
        real(8) :: CaggZ(T_TR), NaggZ(T_TR)
        real(8) :: eff(2,T_TR), red(2,T_TR), ins(2,T_TR)
        integer :: iter, tt, iv, u_out
        logical :: converged

        real(8), allocatable, dimension(:,:)     :: Qmu, H
        real(8), allocatable, dimension(:,:,:,:) :: MOM, w
        real(8), allocatable, dimension(:,:,:)   :: wZ, momZ, CEphi, CEgamma, CEdelta, CElambda

        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        ! Economy A: Initial Steady State
        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        allocate(Qmu(NS,NS))
        call build_transition(a_pol_init, Qmu)

        ! Consumption equivalents: C_aux is the fixed point of
        ! C_aux = u(c) + beta * Qmu * C_aux (see notes).
        u_aux = (c_pol_init**(1d0-sg))/(1d0-sg)
        C_aux = 0d0
        converged = .false.
        do iter = 1, max_iter
            C_aux_new = u_aux + beta*matmul(Qmu, C_aux)
            err = norm2(C_aux - C_aux_new)
            C_aux = C_aux_new
            ! Test convergence first so that converging on the last allowed
            ! iteration is not reported as a failure.
            if (err < tol) then
                converged = .true.
                exit
            end if
        end do
        if (.not. converged) then
            print *, 'Problem when computing CE'
            error stop
        end if

        CE = -1d0 + (1d0 + (V_TR(:,1)-V_init)/C_aux)**(1d0/(1d0-sg))
        print *, 'AGG CE ', sum(mu_init*CE)

        ! Moments: MOM(is, var, period, model)
        allocate(MOM(NS,6,T_TR,2))
        allocate(H(NS,NS))

        call fill_moments(c_pol_init, h_pol_init, matu)
        MOM(:,:,1,1) = matu
        do tt = 2, T_TR
            print *, 'tt = ', tt
            ! MOM(:,:,tt,1) = Qmu**(tt-1) * matu, built one step at a time:
            ! an NS x NS x 6 product per period instead of NS x NS x NS.
            MOM(:,:,tt,1) = matmul(Qmu, MOM(:,:,tt-1,1))
        end do

        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        ! Economy B: Transition
        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        call fill_moments(c_pol_tr(:,1), h_pol_tr(:,1), matu)
        MOM(:,:,1,2) = matu

        ! Invariant at the top of iteration tt: H = Q_1 * Q_2 * ... * Q_{tt-1},
        ! where Q_k is the transition matrix implied by a_pol_tr(:,k).
        call build_transition(a_pol_tr(:,1), Qmu)
        H = Qmu

        do tt = 2, T_TR
            print *, 'tt = ', tt

            call fill_moments(c_pol_tr(:,tt), h_pol_tr(:,tt), matu)

            !$OMP PARALLEL DO
            do iv = 1, 6
                MOM(:,iv,tt,2) = matmul(H, matu(:,iv))
            end do
            !$OMP END PARALLEL DO

            ! The product for the next period is only needed if there is one.
            if (tt < T_TR) then
                call build_transition(a_pol_tr(:,tt), Qmu)
                H = matmul(H, Qmu)
            end if
        end do

        call open_output('Output/MOMload.txt', u_out)
        write(u_out, *) MOM
        close(u_out)

        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        ! Compute efficiency, redistribution, and insurance terms
        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        ! eff/red/ins: 2 (consumption, labor) x T_TR
        eff = 0d0
        red = 0d0
        ins = 0d0

        allocate(w(NS,2,T_TR,2), wZ(NS,2,T_TR), momZ(NS,2,T_TR), CEphi(NS,2,T_TR), &
                 CEgamma(NS,2,T_TR), CEdelta(NS,2,T_TR), CElambda(NS,2,T_TR))

        Cagg_init = sum(c_pol_init*mu_init)
        Hagg_init = sum(h_pol_init*mu_init)
        do tt = 1, T_TR

            Cagg_TR(tt) = sum(mu_tr(:,tt)*c_pol_tr(:,tt))       ! aggregate consumption along transition
            Hagg_TR(tt) = sum(mu_tr(:,tt)*h_pol_tr(:,tt))       ! aggregate hours along transition

            CaggZ(tt) = sqrt(Cagg_init*Cagg_TR(tt))             ! C_t^Z in notes
            NaggZ(tt) = sqrt(Hagg_init*Hagg_TR(tt))             ! N_t^Z in notes

            ! iv = 1: consumption, iv = 2: labor.
            ! MOM(:,2+iv,...) holds E[log x] and MOM(:,4+iv,...) holds E[(log x)**2].
            do iv = 1, 2
                w(:,iv,tt,1) = MOM(:,iv,tt,1)/sum(mu_init*MOM(:,iv,tt,1))   ! w_{x,t}^j, j = steady state
                w(:,iv,tt,2) = MOM(:,iv,tt,2)/sum(mu_init*MOM(:,iv,tt,2))   ! w_{x,t}^j, j = transition
                wZ(:,iv,tt)  = sqrt(w(:,iv,tt,1)*w(:,iv,tt,2))              ! w_{x,t}^Z

                CEdelta(:,iv,tt) = log(w(:,iv,tt,2)) - log(w(:,iv,tt,1))    ! Delta_{x,t}

                ! Lambda_{x,t}: minus half the change in Var(log x).
                ! Integer power: E[log x] may be negative.
                CElambda(:,iv,tt) = -( (MOM(:,4+iv,tt,2) - MOM(:,2+iv,tt,2)**2) &
                                     - (MOM(:,4+iv,tt,1) - MOM(:,2+iv,tt,1)**2) )/2d0
            end do

            momZ(:,1,tt) = CaggZ(tt)*wZ(:,1,tt)                              ! c_t^z
            momZ(:,2,tt) = NaggZ(tt)*wZ(:,2,tt)                              ! n_t^z

            CEphi(:,1,tt) =  (beta**(tt-1))*momZ(:,1,tt)**(1d0-sg)           ! phi_{c,t}
            CEphi(:,2,tt) = -(beta**(tt-1))*B*momZ(:,2,tt)**(1d0+varphi)     ! phi_{n,t}

            CEgamma(:,1,tt) = log(Cagg_TR(tt)) - log(Cagg_init)              ! Gamma_{c,t}
            CEgamma(:,2,tt) = log(Hagg_TR(tt)) - log(Hagg_init)              ! Gamma_{n,t}

            eff(1,tt) = sum(mu_init*(CEphi(:,1,tt)*CEgamma(:,1,tt)))         ! efficiency consumption
            eff(2,tt) = sum(mu_init*(CEphi(:,2,tt)*CEgamma(:,2,tt)))         ! efficiency labor

            red(1,tt) = sum(mu_init*(CEphi(:,1,tt)*CEdelta(:,1,tt)))         ! redistribution consumption
            red(2,tt) = sum(mu_init*(CEphi(:,2,tt)*CEdelta(:,2,tt)))         ! redistribution labor

            ins(1,tt) =  sum(mu_init*sg*(CEphi(:,1,tt)*CElambda(:,1,tt)))    ! insurance consumption
            ins(2,tt) = -sum(mu_init*varphi*(CEphi(:,2,tt)*CElambda(:,2,tt))) ! insurance labor

        end do

        call open_output('Output/CE_eff.txt', u_out)
        write(u_out, *) eff
        close(u_out)

        call open_output('Output/CE_red.txt', u_out)
        write(u_out, *) red
        close(u_out)

        call open_output('Output/CE_ins.txt', u_out)
        write(u_out, *) ins
        close(u_out)

    contains

        !> Build the (a,x) -> (a',x') transition matrix Q for asset policy a_pol.
        !> State s has asset index ia and productivity xind(s); next-period
        !> states are indexed ia' + Na*(ix'-1). Mass is split linearly between
        !> the two asset grid points bracketing a_pol(s).
        subroutine build_transition(a_pol, Q)
            real(8), intent(in)  :: a_pol(:)
            real(8), intent(out) :: Q(:,:)
            real(8) :: a_next(1), wgt
            integer :: idx(1), s, ix, ia, s_next

            Q = 0d0
            do s = 1, NS
                ! Clamp to the grid. The interpolation weight must use the same
                ! clamped value as the search, otherwise wgt leaves [0,1] and Q
                ! gets negative entries.
                a_next(1) = min(a_pol(s), amax)
                if (a_next(1) < a_vec(1)) a_next(1) = a_vec(1)

                idx = vsearchprm_FET(a_next, a_vec, aprm)
                ! NOTE(review): assumes the search returns i with
                ! a_vec(i) <= a < a_vec(i+1); capping at Na-1 keeps ia+1 on the
                ! grid and inside the current productivity block.
                ia = min(max(idx(1), 1), Na-1)

                wgt = (a_next(1) - a_vec(ia))/(a_vec(ia+1) - a_vec(ia))
                do ix = 1, Nx
                    s_next = ia + Na*(ix-1)
                    Q(s,s_next)   = (1d0-wgt)*Px(xind(s),ix)
                    Q(s,s_next+1) =      wgt *Px(xind(s),ix)
                end do
            end do
        end subroutine build_transition

        !> Fill the six per-state moment columns:
        !> c, h, log c, log h, (log c)**2, (log h)**2.
        subroutine fill_moments(c, h, m)
            real(8), intent(in)  :: c(:), h(:)
            real(8), intent(out) :: m(:,:)

            m(:,1) = c
            m(:,2) = h
            m(:,3) = log(c)
            m(:,4) = log(h)
            ! Integer exponent: log c / log h may be negative, and a negative
            ! real base raised to a real power is prohibited by the standard.
            m(:,5) = m(:,3)**2
            m(:,6) = m(:,4)**2
        end subroutine fill_moments

        !> Open fname for formatted output on a fresh unit; abort with a
        !> message if the open fails (e.g. missing Output/ directory).
        subroutine open_output(fname, u)
            character(len=*), intent(in)  :: fname
            integer,          intent(out) :: u
            integer :: ios
            character(len=256) :: msg

            open(newunit=u, file=fname, status='unknown', iostat=ios, iomsg=msg)
            if (ios /= 0) then
                print *, 'Cannot open ', fname, ': ', trim(msg)
                error stop
            end if
        end subroutine open_output

    end subroutine Compute_CE
    
end module ModuleBEGS